{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "from pathlib import Path\n",
    "from tabulate import tabulate\n",
    "from collections import namedtuple\n",
    "from pyAutoSummarizer.base import summarization\n",
    "from functools import partial\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Function: Print Lines\n",
    "def print_in_lines(text, line_length):\n",
    "    words        = text.split()\n",
    "    lines        = []\n",
    "    current_line = ''\n",
    "    for word in words:\n",
    "        if len(current_line + word) <= line_length:\n",
    "            current_line = current_line + word + ' '\n",
    "        else:\n",
    "            lines.append(current_line.strip())\n",
    "            current_line = word + ' '\n",
    "    lines.append(current_line.strip())\n",
    "    for line in lines:\n",
    "        print(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bigsurvey\n",
    "\n",
    "Pair = namedtuple('Pair', ['source', 'reference'])\n",
    "\n",
    "data_pairs = []\n",
    "\n",
    "source_sentences_list, target_sentences_list = [], []\n",
    "\n",
    "for source, reference in zip(\n",
    "    Path(r'D:\\实验室\\2024_03_05课程大纲\\数据\\bigsurvey\\test.src.txt').read_text().split('\\n'), \n",
    "    Path(r'D:\\实验室\\2024_03_05课程大纲\\数据\\bigsurvey\\test.tgt.txt').read_text().split('\\n')\n",
    "):\n",
    "    # source_sentences = nltk.tokenize.sent_tokenize(source)\n",
    "    # target_sentences = nltk.tokenize.sent_tokenize(reference)\n",
    "    data_pairs.append(Pair(source, reference))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/452 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 452/452 [00:26<00:00, 17.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average word length: 12042.444690265487\n",
      "average sentence length: 455.7278761061947\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 统计一下词的数量\n",
    "word_lens = []\n",
    "sent_lens = []\n",
    "for source, reference in tqdm(data_pairs):\n",
    "    words = nltk.word_tokenize(source)\n",
    "    word_lens.append(len(words))\n",
    "    sentences = nltk.sent_tokenize(source)\n",
    "    sent_lens.append(len(sentences))\n",
    "\n",
    "print(f'average word length: {sum(word_lens) / len(word_lens)}')\n",
    "print(f'average sentence length: {sum(sent_lens) / len(sent_lens)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "textrank_n_50: 100%|██████████| 452/452 [41:02<00:00,  5.45s/it] \n"
     ]
    }
   ],
   "source": [
    "parameters = { 'stop_words':        ['en'],\n",
    "    'n_words':           -1,\n",
    "    'n_chars':           -1,\n",
    "    'lowercase':         True,\n",
    "    'rmv_accents':       True,\n",
    "    'rmv_special_chars': True,\n",
    "    'rmv_numbers':       False,\n",
    "    'rmv_custom_words':  [],\n",
    "    'verbose':           False\n",
    "}\n",
    "\n",
    "model_funcs = {\n",
    "    # 'textrank_n_50': (\n",
    "    #     partial(summarization.summ_text_rank, iteration = 1000, D = 0.85, model = 'all-MiniLM-L6-v2'),\n",
    "    #     partial(summarization.show_summary, n=50)\n",
    "    # ),\n",
    "    'lexrank_n_50': (\n",
    "        partial(summarization.summ_lex_rank, iteration = 1000, D = 0.85),\n",
    "        partial(summarization.show_summary, n=50)\n",
    "    ),\n",
    "    'lsa_n_50': (\n",
    "        partial(summarization.summ_ext_LSA, embeddings = False, model = 'all-MiniLM-L6-v2'),\n",
    "        partial(summarization.show_summary, n=50)\n",
    "    ),\n",
    "    'kl_n_50': (\n",
    "        partial(summarization.summ_ext_KL, n=3),\n",
    "        partial(summarization.show_summary, n=50),\n",
    "    ),\n",
    "    # # 最长1024，TODO: 层次式方法\n",
    "    # 'bart_len_1250': (\n",
    "    #     partial(summarization.summ_ext_bart, model = 'facebook/bart-large-cnn', max_len = 250),\n",
    "    #     lambda r: r,\n",
    "    # ),\n",
    "    # # TODO: 对比长上下文的模型\n",
    "}\n",
    "\n",
    "output_dir = Path(r'D:\\实验室\\2024_03_05课程大纲\\数据\\bigsurvey\\output')\n",
    "\n",
    "for model_name, (exec_func, summ_func) in model_funcs.items():\n",
    "    for source, reference in tqdm(data_pairs, desc=model_name):\n",
    "        smr = summarization(source, **parameters)\n",
    "        rank = exec_func(smr)\n",
    "        summary = summ_func(smr, rank)\n",
    "        generated_summary = smr.show_summary(rank, n = 50)\n",
    "\n",
    "        file = Path(output_dir / model_name).open('a')\n",
    "        file.write(generated_summary + '\\n')\n",
    "        file.close()\n",
    "\n",
    "# 确定生成的长度\n",
    "# for source, reference in tqdm(data_pairs):\n",
    "#     # Create Instance\n",
    "#     # print(source)\n",
    "#     smr = summarization(source, **parameters)\n",
    "#     rank = smr.summ_text_rank(iteration = 1000, D = 0.85, model = 'all-MiniLM-L6-v2')\n",
    "#     generated_summary = smr.show_summary(rank, n = 3)\n",
    "#     # print_in_lines(generated_summary, line_length = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "h:\\Anaconda\\envs\\wsln\\Lib\\site-packages\\sentence_transformers\\cross_encoder\\CrossEncoder.py:13: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "If your work environment makes you constantly upset because it's hostile or you just don't like your\n",
      "job, you may need to find yourself new employment where you can be happier If you can't garner\n",
      "interest in any part of your life, that could be a sign of depression If you're really unhappy as a\n",
      "stay-at-home parent, maybe it's time to start thinking about alternatives, such as going to work\n",
      "If your work environment makes you constantly upset because it's hostile or you just don't like your\n",
      "job, you may need to find yourself new employment where you can be happier If you're really unhappy\n",
      "as a stay-at-home parent, maybe it's time to start thinking about alternatives, such as going to\n",
      "work Start a job hunt now so that you can find something you love\n",
      "Other signs may include general sadness, anxiety, fatigue, brain forgetfulness, and irritability\n",
      "Maybe your workload is just too much If you can't garner interest in any part of your life, that\n",
      "could be a sign of depression\n",
      "Maybe your workload is just too much If you can't garner interest in any part of your life, that\n",
      "could be a sign of depression Other signs may include general sadness, anxiety, fatigue, brain\n",
      "forgetfulness, and irritability\n",
      "If your work environment makes you constantly upset because it's hostile or you just don't like your\n",
      "job, you may need to find yourself new employment where you can be happier. If you can't garner\n",
      "interest in any part of your life, that could be a sign of depression. See your doctor if you're\n",
      "exhibiting these symptoms.\n"
     ]
    }
   ],
   "source": [
    "# baselines\n",
    "# Extractive Summarization: textrank, lexrank, lsa, kl-sum, BART, T5， textteaser\n",
    "# Abstractive Summarization: Large Language Model, PEGASUS\n",
    "\n",
    "from pyAutoSummarizer.base import summarization\n",
    "\n",
    "text = \"\"\" \n",
    "        If your work environment makes you constantly upset because it's hostile or you just don't like your job, you may need to find yourself new employment where you can be happier. \n",
    "        Maybe your workload is just too much. Start a job hunt now so that you can find something you love. If you're unhappy in your degree program, maybe you need to change fields, \n",
    "        or maybe you need to try something different altogether. If you're really unhappy as a stay-at-home parent, maybe it's time to start thinking about alternatives, such as going to work. \n",
    "        If you can't garner interest in any part of your life, that could be a sign of depression. Ask your doctor for more information. Other signs may include general sadness, anxiety, fatigue, \n",
    "        brain forgetfulness, and irritability. See your doctor if you're exhibiting these symptoms.\n",
    "\n",
    "       \"\"\"\n",
    "text = text.replace('\\n', '').replace('        ', '')\n",
    "\n",
    "# Load Reference Summary (Human Made Summary for Benchmark)\n",
    "reference_summary = \"\"\"\n",
    "                        Look at your work environment. Check for declining interest in other parts of your life. See if you exhibit other signs of depression.\n",
    "                    \"\"\"\n",
    "reference_summary = reference_summary.replace('\\n', '').replace('        ', '')\n",
    "\n",
    "# Load Parameters\n",
    "parameters = { 'stop_words':        ['en'],\n",
    "               'n_words':           -1,\n",
    "               'n_chars':           -1,\n",
    "               'lowercase':         True,\n",
    "               'rmv_accents':       True,\n",
    "               'rmv_special_chars': True,\n",
    "               'rmv_numbers':       False,\n",
    "               'rmv_custom_words':  [],\n",
    "               'verbose':           False\n",
    "              }\n",
    "\n",
    "# Create Instance\n",
    "smr = summarization(text, **parameters)\n",
    "\n",
    "# Rank Sentences\n",
    "rank = smr.summ_text_rank(iteration = 1000, D = 0.85, model = 'all-MiniLM-L6-v2')\n",
    "# Show Summary\n",
    "generated_summary = smr.show_summary(rank, n = 3)\n",
    "\n",
    "# Print Summary - TextRank\n",
    "print_in_lines(generated_summary, line_length = 100)\n",
    "\n",
    "\n",
    "rank = smr.summ_lex_rank(iteration = 1000, D = 0.85)\n",
    "lexrank_g = smr.show_summary(rank, n = 3)\n",
    "\n",
    "print_in_lines(lexrank_g, line_length = 100)\n",
    "\n",
    "rank = smr.summ_ext_LSA(embeddings = False, model = 'all-MiniLM-L6-v2')\n",
    "lsa_g = smr.show_summary(rank, n = 3)\n",
    "print_in_lines(lsa_g, line_length = 100)\n",
    "\n",
    "\n",
    "rank = smr.summ_ext_KL(n = 3)\n",
    "kl_g = smr.show_summary(rank, n = 3)\n",
    "print_in_lines(kl_g, line_length = 100)\n",
    "\n",
    "bart_g = smr.summ_ext_bart(model = 'facebook/bart-large-cnn', max_len = 250)\n",
    "print_in_lines(bart_g, line_length = 100)\n",
    "\n",
    "# generated_summary = smr.summ_ext_t5(model = 't5-base', min_len = 30, max_len = 500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start a job hunt now so that you can find something you love . if you're unhappy in your degree\n",
      "program, maybe you need to change fields or try something different altogether. general sadness,\n",
      "anxiety, fatigue, brain forgetfulness, and irritability are signs of depression.\n"
     ]
    }
   ],
   "source": [
    "t5_g = smr.summ_ext_t5(model = 't5-base', min_len = 30, max_len = 500)\n",
    "print_in_lines(t5_g, line_length = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "╒═══════════╤══════════════════╤══════════════════╤══════════════════╤══════════════════╤═════════════════╤═════════════════╕\n",
      "│ Method    │ rouge-1-F1/P/R   │ rouge-2-F1/P/R   │ rouge-l-F1/P/R   │ rouge-s-F1/P/R   │ bleu-F1/P/R     │ meteor-F1/P/R   │\n",
      "╞═══════════╪══════════════════╪══════════════════╪══════════════════╪══════════════════╪═════════════════╪═════════════════╡\n",
      "│ text_rank │ 29.41/20.83/50.0 │ 6.06/4.17/11.11  │ 28.57/20.0/50.0  │ 4.26/4.26/4.26   │ 9.13/-1.0/-1.0  │ 43.48/-1.0/-1.0 │\n",
      "├───────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┼─────────────────┼─────────────────┤\n",
      "│ lexrank   │ 13.33/10.0/20.0  │ 6.25/4.35/11.11  │ 11.76/8.33/20.0  │ 2.27/2.27/2.27   │ 6.02/-1.0/-1.0  │ 17.54/-1.0/-1.0 │\n",
      "├───────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┼─────────────────┼─────────────────┤\n",
      "│ lsa       │ 30.77/25.0/40.0  │ 0.0/0.0/0.0      │ 23.08/18.75/30.0 │ 3.45/3.45/3.45   │ 25.0/-1.0/-1.0  │ 37.74/-1.0/-1.0 │\n",
      "├───────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┼─────────────────┼─────────────────┤\n",
      "│ bart      │ 34.48/26.32/50.0 │ 7.41/5.56/11.11  │ 34.48/26.32/50.0 │ 5.71/5.71/5.71   │ 12.09/-1.0/-1.0 │ 45.87/-1.0/-1.0 │\n",
      "├───────────┼──────────────────┼──────────────────┼──────────────────┼──────────────────┼─────────────────┼─────────────────┤\n",
      "│ t5        │ 13.33/10.0/20.0  │ 7.14/5.26/11.11  │ 13.33/10.0/20.0  │ 2.7/2.7/2.7      │ 7.25/-1.0/-1.0  │ 18.18/-1.0/-1.0 │\n",
      "╘═══════════╧══════════════════╧══════════════════╧══════════════════╧══════════════════╧═════════════════╧═════════════════╛\n"
     ]
    }
   ],
   "source": [
    "# ROUGE N\n",
    "from functools import partial\n",
    "\n",
    "def calculate_scores(generated_summary, reference_summary):\n",
    "    # ROUGE 1, ROUGE 2, ROUGE L, ROUGE S, BLEU, METEOR\n",
    "    metric_funcs = {\n",
    "        'rouge-1': partial(smr.rouge_N, n = 1),\n",
    "        'rouge-2': partial(smr.rouge_N, n = 2),\n",
    "        'rouge-l':smr.rouge_L,\n",
    "        'rouge-s': partial(smr.rouge_S, skip_distance = 2),\n",
    "        'bleu': partial(smr.bleu, n = 4),\n",
    "        'meteor': smr.meteor,\n",
    "    }\n",
    "\n",
    "    score_mapper = {}\n",
    "\n",
    "    for metric, func in metric_funcs.items():\n",
    "        score = func(generated_summary = generated_summary, reference_summary = reference_summary)\n",
    "\n",
    "        norm = lambda score: round(score * 100, 2)\n",
    "\n",
    "        if isinstance(score, tuple):\n",
    "            f1, precision, recall = score\n",
    "            score_mapper[metric] = [norm(f1), norm(precision), norm(recall)]\n",
    "            # {\n",
    "            #     'f1': f1, 'precision': precision, 'recall': recall,\n",
    "            # }\n",
    "        else:\n",
    "            score = norm(score)\n",
    "            score_mapper[metric] = [score, -1, -1]\n",
    "\n",
    "    return score_mapper\n",
    "\n",
    "def get_avg_scores(score_mappers):\n",
    "    avg_mapper = {}\n",
    "    for score_mapper in score_mappers:\n",
    "        for metric, scores in score_mapper.items():\n",
    "            avg_mapper[metric] = avg_mapper.get(metric, [0, 0, 0])\n",
    "            for index, score in enumerate(scores):\n",
    "                avg_mapper[metric][index] += score\n",
    "\n",
    "    for metric, scores in avg_mapper.items():\n",
    "        for index, score in enumerate(scores):\n",
    "            avg_mapper[metric][index] = score / len(score_mappers)\n",
    "\n",
    "    return avg_mapper\n",
    "\n",
    "def compose_row(generated_summaries, reference_summaries, method_name):\n",
    "    \n",
    "    score_mappers = [calculate_scores(g, r) for g, r in zip(generated_summaries, reference_summaries)]\n",
    "    avg_mapper = get_avg_scores(score_mappers)\n",
    "\n",
    "    values = []; headers = ['Method']\n",
    "    for metric, scores in avg_mapper.items():\n",
    "        if len(scores) == 1:\n",
    "            value = scores[0]\n",
    "        else:\n",
    "            metric = f'{metric}-F1/P/R'\n",
    "            f1, p, r = scores\n",
    "            value = f'{f1}/{p}/{r}'\n",
    "        values.append(value)\n",
    "        headers.append(metric)\n",
    "    \n",
    "    row = [method_name, *values]\n",
    "    return row, headers\n",
    "\n",
    "row1, headers = compose_row([generated_summary, generated_summary], [reference_summary, reference_summary], 'text_rank')\n",
    "row2, headers = compose_row([lexrank_g], [reference_summary], 'lexrank')\n",
    "row3, headers = compose_row([lsa_g], [reference_summary], 'lsa')\n",
    "row_bart, headers = compose_row([bart_g], [reference_summary], 'bart')\n",
    "row_t5, headers = compose_row([t5_g], [reference_summary], 't5')\n",
    "table = tabulate([row1, row2, row3, row_bart, row_t5], headers = headers, tablefmt = 'fancy_grid')\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "╒══════════════════════╤═════════╕\n",
      "│ Evaluation Metric    │   Score │\n",
      "╞══════════════════════╪═════════╡\n",
      "│ Rouge N (F1 Score):  │    0.06 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge N (Precision): │    0.04 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge N (Recall):    │    0.11 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge L (F1 Score):  │    0.29 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge L (Precision): │    0.2  │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge L (Recall):    │    0.5  │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge S (F1 Score):  │    0.04 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge S (Precision): │    0.04 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ Rouge S (Recall):    │    0.04 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ BLEU Score:          │    0.09 │\n",
      "├──────────────────────┼─────────┤\n",
      "│ METEOR Score:        │    0.43 │\n",
      "╘══════════════════════╧═════════╛\n"
     ]
    }
   ],
   "source": [
    "# ROUGE N\n",
    "f1_N, precision_N, recall_N = smr.rouge_N(generated_summary = generated_summary, reference_summary = reference_summary, n = 2)\n",
    "rouge_N_scores              = [['Rouge N (F1 Score):', round(f1_N, 2)], ['Rouge N (Precision):', round(precision_N, 2)], ['Rouge N (Recall):', round(recall_N, 2)]]\n",
    "\n",
    "# ROUGE L\n",
    "f1_L, precision_L, recall_L = smr.rouge_L(generated_summary = generated_summary, reference_summary = reference_summary)\n",
    "rouge_L_scores              = [['Rouge L (F1 Score):', round(f1_L, 2)], ['Rouge L (Precision):', round(precision_L, 2)], ['Rouge L (Recall):', round(recall_L, 2)]]\n",
    "\n",
    "# ROUGE S\n",
    "f1_S, precision_S, recall_S = smr.rouge_S(generated_summary = generated_summary, reference_summary = reference_summary, skip_distance = 2)\n",
    "rouge_S_scores              = [['Rouge S (F1 Score):', round(f1_S, 2)], ['Rouge S (Precision):', round(precision_S, 2)], ['Rouge S (Recall):', round(recall_S, 2)]]\n",
    "\n",
    "# BLEU\n",
    "score_b                     = smr.bleu(generated_summary = generated_summary, reference_summary = reference_summary, n = 4)\n",
    "bleu_score                  = [['BLEU Score:', round(score_b, 2)]]\n",
    "\n",
    "# METEOR\n",
    "score_m                     = smr.meteor(generated_summary = generated_summary, reference_summary = reference_summary)\n",
    "meteor_score                = [['METEOR Score:', round(score_m, 2)]]\n",
    "\n",
    "# Table\n",
    "table_data = rouge_N_scores + rouge_L_scores + rouge_S_scores + bleu_score + meteor_score\n",
    "table      = tabulate(table_data, headers = ['Evaluation Metric', 'Score'], tablefmt = 'fancy_grid')\n",
    "print(table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wsln",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
